#! /usr/bin/env python3
###########################################################################
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
###########################################################################

import argparse
import sys

###########################################################################
#
# Define the steps.
#
# Each step in the compilation flow is defined by a tuple.  The first
# item in the tuple is a string which indicates the step type.  Most
# steps simply invoke a tool, such as clang, opt or llvm-link.  Other
# steps are sequences of steps.
#
# The interpretation of the remainder of the tuple depends on the step
# type as follows:
#   seq - The names of the steps in the sequence.
#   clang - N/A.
#   opt - Command line arguments to pass to opt.
#   link - The name of the LLVM-IR file to link.
#   hls_out - N/A.
step_specs = {
    # The top-level sequence: everything in the selected front-end
    # followed by everything in the selected back-end.
    'all': ('seq', 'front-end', 'back-end'),

    # Steps of type 'clang' (no further arguments).
    'clang-xdp': ('clang',),
    'clang-nanotube': ('clang',),

    # Steps of type 'opt'.  The second item is the opt command line.
    'ebpf2nt': (
        'opt',
        '-codegenprepare -instsimplify -dce -ebpf2nanotube',
    ),
    'mem2req': ('opt', '-instsimplify -mem2req'),
    'inline': ('opt', '-always-inline -instsimplify'),
    'platform': ('opt', '-platform -always-inline -instsimplify'),
    'control': ('opt', '-control-capsule -always-inline'),
    'ntattr': ('opt', '-nt-attributes -O2'),
    'optreq': (
        'opt',
        '-mergereturn -optreq '
        '-enable-loop-unroll -always-inline '
        '-instsimplify -loop-unroll -simplifycfg',
    ),
    'converge': ('opt', '-move-alloca -compact-geps -converge_mapa'),
    'inline_opt': (
        'opt',
        '-always-inline -rewrite-setup -replace-malloc '
        '-thread-const -instsimplify '
        '-enable-loop-unroll -loop-unroll '
        '-move-alloca -simplifycfg -instcombine '
        '-thread-const -instsimplify -simplifycfg',
    ),
    'pipeline': (
        'opt',
        '-compact-geps -basic-aa -tbaa -nanotube-aa -pipeline',
    ),
    'byteify': ('opt', '-byteify'),
    'destruct': ('opt', '-destruct'),
    'flatten': ('opt', '-flatten-cfg'),

    # Steps of type 'link'.  The second item is the LLVM-IR file to link.
    'lower': ('link', 'nanotube_high_level.bc'),
    'link_taps': ('link', 'nanotube_low_level.bc'),

    # The HLS output step (no further arguments).
    'hls': ('hls_out',),

    # An empty sequence, i.e. no steps at all.
    'none': ('seq',),
}

###########################################################################
#
# Define the front-ends and back-ends.
#

# Maps each value accepted by --front-end to the 'seq' step which
# becomes the 'front-end' step.
front_end_specs = {
    'xdp-c': ('seq', 'clang-xdp', 'ebpf2nt'),
    'nanotube-c': ('seq', 'clang-nanotube'),
}

# Maps each value accepted by --back-end to the 'seq' step which
# becomes the 'back-end' step.  The order of the names is the order
# in which the steps are listed.
back_end_specs = {
    'hls': (
        'seq',
        'mem2req', 'lower', 'inline', 'platform', 'control', 'ntattr',
        'optreq', 'converge', 'pipeline', 'link_taps', 'inline_opt',
        'hls',
    ),
}

###########################################################################
#
# Step classes.
#

class step_def:
    """Base class for a named step definition.

    A step definition holds the step name and the arguments taken from
    its specification (everything after the step type).  By default a
    step expands to just itself; subclasses may override expand_step.
    """

    def __init__(self, name, args):
        """Record the step name and its specification arguments."""
        self.__name, self.__args = name, args

    def expand_step(self, the_app, name, n):
        """Return the list of step names this step expands to.

        The base implementation ignores all of its parameters and
        returns a single-element list holding this step's own name.
        """
        return [self.__name]

    def get_args(self):
        """Return the arguments given when the step was defined."""
        return self.__args

    def get_name(self):
        """Return the name given when the step was defined."""
        return self.__name

class step_seq(step_def):
    """A step which stands for a sequence of other steps.

    The step arguments are the names of the member steps, in order.
    """

    def expand_step(self, the_app, name, n):
        """Return the flattened list of step names in this sequence.

        Each member name is expanded through the_app.get_steps and the
        results are concatenated in order.  n is the remaining nesting
        budget: it is reduced by one for each level of nesting, and an
        assertion fires once it goes negative, which indicates that
        the sequence definitions refer back to themselves.
        """
        assert n >= 0, "Step loop detected in "+repr(name)
        return [
            leaf
            for member in self.get_args()
            for leaf in the_app.get_steps(member, n-1)
        ]

class step_clang(step_def):
    """Step of type 'clang'.

    Adds nothing to step_def; it exists so that the 'clang' step type
    in a step specification resolves to a class named step_clang.
    """

class step_opt(step_def):
    """Step of type 'opt'.

    Adds nothing to step_def; it exists so that the 'opt' step type in
    a step specification resolves to a class named step_opt.  The step
    arguments hold the command line arguments to pass to opt.
    """

class step_link(step_def):
    """Step of type 'link'.

    Adds nothing to step_def; it exists so that the 'link' step type
    in a step specification resolves to a class named step_link.  The
    step arguments hold the name of the LLVM-IR file to link.
    """

class step_hls_out(step_def):
    """Step of type 'hls_out'.

    Adds nothing to step_def; it exists so that the 'hls_out' step
    type in a step specification resolves to a class named
    step_hls_out.
    """

###########################################################################
#
# The top-level application class.
#

class app:
    """Command line driver which works out the compilation steps to run.

    The application parses the command line, builds a table of step
    definitions from the module-level step_specs, front_end_specs and
    back_end_specs tables, and expands the --steps option into a flat
    list of step names.  At present the only supported action is
    listing that flat list (-l); any other invocation writes an error
    message to stderr.
    """

    def __init__(self, argv):
        """Create the application.

        argv is the full command line.  argv[0] is used as the prefix
        of every error message and argv[1:] is parsed as the options.
        """
        # The command line parameters.
        self.__argv = argv
        # The parsed command line parameters.
        self.__args = None
        # The names of all the steps for the selected
        # front-end/back-end.
        self.__all_steps = []
        # The names of the selected steps.
        self.__steps = []
        # The dictionary which maps step names to step definitions.
        self.__step_defs = {}

    def parse_args(self):
        """Parse the command line and store the result for later use."""
        p = argparse.ArgumentParser()
        p.add_argument('-f', '--front-end', action="store",
                       help="Set the front-end type.")
        p.add_argument('-b', '--back-end', action="store",
                       help="Set the back-end type.")
        p.add_argument('-o', '--output', action="store",
                       help="Set the output filename.")
        p.add_argument('-p', '--project', action="store",
                       help="Set the project directory.")
        p.add_argument('-s', '--steps', action="store", default="all",
                       help="Set the steps to run.")
        p.add_argument('-l', '--list-steps', action="store_true",
                       help="List the steps to run.")
        p.add_argument('input', nargs='*',
                       help="The input files to process.")
        self.__args = p.parse_args(self.__argv[1:])

    def add_step_spec(self, name, spec):
        """Register the step <name> described by the tuple <spec>.

        spec[0] is the step type; the class implementing it is found
        by looking up "step_" + type in the module globals.  The
        remaining items of spec are passed to that class as the step
        arguments.  An existing definition with the same name is
        replaced.
        """
        step_type = spec[0]
        step_class = globals().get("step_"+step_type)
        # Every caller in this class passes a spec taken from one of
        # the tables in this file, so an unknown type is treated as a
        # programming error rather than a user error.
        assert step_class is not None, "Unknown step type "+step_type
        step_args = spec[1:]
        self.__step_defs[name] = step_class(name, step_args)

    def define_steps(self, arg, specs, name):
        """Used to define front-end and back-end steps based on the command
        line arguments.

        arg is the value given on the command line (or None if the
        option was not given, in which case nothing is defined), specs
        is the table to look it up in and name is the step name to
        define ("front-end" or "back-end").  The expanded steps are
        also appended to the list of all steps.  An unknown value is
        reported on stderr and the program exits with status 1.
        """

        if arg is None:
            return

        spec = specs.get(arg)
        if spec is None:
            sys.stderr.write("%s: Unknown %s %r.\n" %
                             (self.__argv[0], name, arg))
            sys.exit(1)

        self.add_step_spec(name, spec)
        self.__all_steps.extend(self.get_steps(name))

    def get_steps(self, name, n=None):
        """Get the sequence of steps with name <name>.

        n is the nesting budget handed to the step's expand_step
        method.  When it is omitted it defaults to the number of
        defined steps.  If the step is not defined then an error is
        written to stderr and the program exits with status 1.
        """

        if n is None:
            n = len(self.__step_defs)

        step = self.__step_defs.get(name)
        if step is None:
            # The front-end and back-end steps only exist once the
            # matching command line option has been processed, so give
            # a more helpful message for those.
            if name in ('front-end', 'back-end'):
                sys.stderr.write("%s: Cannot determine steps when the"
                                 " %s is unset.\n" %
                                 (self.__argv[0], name))
            else:
                sys.stderr.write("%s: Unknown step %r.\n" %
                                 (self.__argv[0], name))

            sys.exit(1)

        return step.expand_step(self, name, n)

    def find_step_index(self, seq_name, seq, step_name, dflt):
        """Find the specified step in a sequence of steps.

        Returns dflt when step_name is the empty string, otherwise the
        index of the first occurrence of step_name in seq.  If the step
        is not in the sequence then an error is written to stderr and
        the program exits with status 1.
        """

        if step_name == "":
            return dflt

        try:
            return seq.index(step_name)
        except ValueError:
            pass

        # If the sequence name is "all" then add_step_range did not
        # expand the full list.  Do that here in case it generates an
        # error of its own; the result itself is not needed.
        if seq_name == "all":
            self.get_steps(seq_name)

        sys.stderr.write("%s: Step %r is not in sequence %r.\n" %
                         (self.__argv[0], step_name, seq_name))
        sys.exit(1)

    def add_step_range(self, seq_name, start, end):
        """Add the specified step range to the sequence of steps to perform.

        The range runs from step <start> to step <end> of the sequence
        <seq_name>, both inclusive.  An empty <start> means the first
        step of the sequence and an empty <end> means the last.
        """

        # If the sequence name is "all" then use the steps which have
        # been collected so far.  If that works then there's no need
        # to expand the full list.  If it doesn't work then
        # find_step_index will expand the full list in order to
        # generate the correct error message.
        if seq_name != "all":
            seq = self.get_steps(seq_name)
        else:
            seq = self.__all_steps

        start_idx = self.find_step_index(seq_name, seq, start, 0)
        end_idx = self.find_step_index(seq_name, seq, end, len(seq))

        # The end step is included, hence the +1.  The default end
        # index of len(seq) makes the slice run to the end of seq.
        self.__steps.extend(seq[start_idx:end_idx+1])

    def add_steps(self, spec):
        """Add the specified steps to the sequence of steps to perform.

        The specification takes one of the following forms:
          NAME           - All the steps of the step called NAME.
          START:END      - A range of the steps in sequence "all".
          SEQ:START:END  - A range of the steps in sequence SEQ.
        Any other form is reported on stderr and the program exits
        with status 1.
        """

        parts = spec.split(":")
        if len(parts) == 1:
            self.__steps.extend(self.get_steps(parts[0]))
            return

        if len(parts) == 2:
            self.add_step_range("all", parts[0], parts[1])
            return

        if len(parts) == 3:
            self.add_step_range(parts[0], parts[1], parts[2])
            return

        # Too many fields.  Treat this like every other bad step
        # specification: report it and stop instead of silently
        # ignoring it and carrying on.
        sys.stderr.write("%s: Unknown step specification %r.\n" %
                         (self.__argv[0], spec))
        sys.exit(1)

    def process_args(self):
        """Perform common processing to interpret the command line arguments.

        parse_args must have been called first.
        """

        # Set the common step definitions.
        for name, spec in step_specs.items():
            self.add_step_spec(name, spec)

        # Set the front-end and back-end steps.
        self.define_steps(self.__args.front_end, front_end_specs, "front-end")
        self.define_steps(self.__args.back_end, back_end_specs, "back-end")

        # Determine the steps to run.  The option value is a comma
        # separated list of step specifications.
        for spec in self.__args.steps.split(","):
            self.add_steps(spec)

    def list_steps(self):
        """Write the list of steps to stdout and exit with status 0."""

        for step in self.__steps:
            print(step)
        sys.exit(0)

    def run(self):
        """Parse the command line arguments and perform the requested actions.
        """

        self.parse_args()
        self.process_args()
        if self.__args.list_steps:
            self.list_steps()
            return
        # NOTE(review): this reports an error but returns normally, so
        # the process exit status stays zero.  Left as is in case
        # callers rely on it - confirm whether it should exit non-zero.
        sys.stderr.write("%s: Only -l is supported.\n" %
                         (self.__argv[0],))

# Run the application on the process command line.  There is no
# __main__ guard, so this also happens when the file is imported.
app(sys.argv).run()
